{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training with Chainer\n",
    "\n",
    "[VGG](https://arxiv.org/pdf/1409.1556v6.pdf) is an architecture for deep convolution networks. In this example, we train a convolutional network to perform image classification using the CIFAR-10 dataset. CIFAR-10 consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images. We'll train a model on SageMaker, deploy it to Amazon SageMaker hosting, and then classify images using the deployed model.\n",
    "\n",
    "The Chainer script runs inside of a Docker container running on SageMaker. For more information about the Chainer container, see the sagemaker-chainer-containers repository and the sagemaker-python-sdk repository:\n",
    "\n",
    "* https://github.com/aws/sagemaker-chainer-containers\n",
    "* https://github.com/aws/sagemaker-python-sdk\n",
    "\n",
    "For more on Chainer, please visit the Chainer repository:\n",
    "\n",
    "* https://github.com/chainer/chainer\n",
    "\n",
    "This notebook is adapted from the [CIFAR-10](https://github.com/chainer/chainer/tree/master/examples/cifar) example in the Chainer repository."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Setup\n",
    "from sagemaker import get_execution_role\n",
    "import sagemaker\n",
    "\n",
    "sagemaker_session = sagemaker.Session()\n",
    "\n",
    "# This role retrieves the SageMaker-compatible role used by this Notebook Instance.\n",
    "role = get_execution_role()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Downloading training and test data\n",
    "\n",
    "We use helper functions provided by `chainer` to download and preprocess the CIFAR10 data. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import chainer\n",
    "\n",
    "from chainer.datasets import get_cifar10\n",
    "\n",
    "train, test = get_cifar10()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Uploading the data\n",
    "\n",
    "We save the preprocessed data to the local filesystem, and then use the `sagemaker.Session.upload_data` function to upload our datasets to an S3 location. The return value `inputs` identifies the S3 location, which we will use when we start the Training Job."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "train_data = [element[0] for element in train]\n",
    "train_labels = [element[1] for element in train]\n",
    "\n",
    "test_data = [element[0] for element in test]\n",
    "test_labels = [element[1] for element in test]\n",
    "\n",
    "\n",
    "try:\n",
    "    os.makedirs('/tmp/data/train_cifar')\n",
    "    os.makedirs('/tmp/data/test_cifar')\n",
    "    np.savez('/tmp/data/train_cifar/train.npz', data=train_data, labels=train_labels)\n",
    "    np.savez('/tmp/data/test_cifar/test.npz', data=test_data, labels=test_labels)\n",
    "    train_input = sagemaker_session.upload_data(\n",
    "                      path=os.path.join('/tmp', 'data', 'train_cifar'),\n",
    "                      key_prefix='notebook/chainer_cifar/train')\n",
    "    test_input = sagemaker_session.upload_data(\n",
    "                      path=os.path.join('/tmp', 'data', 'test_cifar'),\n",
    "                      key_prefix='notebook/chainer_cifar/test')\n",
    "finally:\n",
    "    shutil.rmtree('/tmp/data')\n",
    "print('training data at %s', train_input)\n",
    "print('test data at %s', test_input)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Writing the Chainer script to run on Amazon SageMaker\n",
    "\n",
    "### Training\n",
    "\n",
    "We need to provide a training script that can run on the SageMaker platform. The training script is very similar to a training script you might run outside of SageMaker, but you can access useful properties about the training environment through various environment variables, such as:\n",
    "\n",
    "* `SM_MODEL_DIR`: A string representing the path to the directory to write model artifacts to.\n",
    "  These artifacts are uploaded to S3 for model hosting.\n",
    "* `SM_NUM_GPUS`: An integer representing the number of GPUs available to the host.\n",
    "* `SM_OUTPUT_DIR`: A string representing the filesystem path to write output artifacts to. Output artifacts may\n",
    "  include checkpoints, graphs, and other files to save, not including model artifacts. These artifacts are compressed\n",
    "  and uploaded to S3 to the same S3 prefix as the model artifacts.\n",
    "\n",
    "Supposing two input channels, 'train' and 'test', were used in the call to the Chainer estimator's ``fit()`` method,\n",
    "the following will be set, following the format `SM_CHANNEL_[channel_name]`:\n",
    "\n",
    "* `SM_CHANNEL_TRAIN`: A string representing the path to the directory containing data in the 'train' channel\n",
    "* `SM_CHANNEL_TEST`: Same as above, but for the 'test' channel.\n",
    "\n",
    "A typical training script loads data from the input channels, configures training with hyperparameters, trains a model, and saves a model to `model_dir` so that it can be hosted later. Hyperparameters are passed to your script as arguments and can be retrieved with an `argparse.ArgumentParser` instance. For example, the script run by this notebook starts with the following:\n",
    "\n",
    "```python\n",
    "import argparse\n",
    "import os\n",
    "\n",
    "if __name__ =='__main__':\n",
    "\n",
    "    parser = argparse.ArgumentParser()\n",
    "\n",
    "    # retrieve the hyperparameters we set from the client (with some defaults)\n",
    "    parser.add_argument('--epochs', type=int, default=50)\n",
    "    parser.add_argument('--batch-size', type=int, default=64)\n",
    "    parser.add_argument('--learning-rate', type=float, default=0.05)\n",
    "\n",
    "    # Data, model, and output directories These are required.\n",
    "    parser.add_argument('--output-data-dir', type=str, default=os.environ['SM_OUTPUT_DATA_DIR'])\n",
    "    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])\n",
    "    parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])\n",
    "    parser.add_argument('--test', type=str, default=os.environ['SM_OUTPUT_DATA_DIR'])\n",
    "    \n",
    "    args, _ = parser.parse_known_args()\n",
    "    \n",
    "    num_gpus = int(os.environ['SM_NUM_GPUS'])\n",
    "    \n",
    "    # ... load from args.train and args.test, train a model, write model to args.model_dir.\n",
    "```\n",
    "\n",
    "Because the Chainer container imports your training script, you should always put your training code in a main guard (`if __name__=='__main__':`) so that the container does not inadvertently run your training code at the wrong point in execution.\n",
    "\n",
    "For more information about training environment variables, please visit https://github.com/aws/sagemaker-containers.\n",
    "\n",
    "### Hosting and Inference\n",
    "\n",
    "We use a single script to train and host the Chainer model. You can also write separate scripts for training and hosting. In contrast with the training script, the hosting script requires you to implement functions with particular function signatures (or rely on defaults for those functions).\n",
    "\n",
    "These functions load your model, deserialize data sent by a client, obtain inferences from your hosted model, and serialize predictions back to a client:\n",
    "\n",
    "\n",
    "* **`model_fn(model_dir)` (always required for hosting)**: This function is invoked to load model artifacts from those that were written into `model_dir` during training.\n",
    "\n",
    "The script that this notebook runs uses the following `model_fn` function for hosting:\n",
    "```python\n",
    "def model_fn(model_dir):\n",
    "    chainer.config.train = False\n",
    "    model = L.Classifier(net.VGG(10))\n",
    "    serializers.load_npz(os.path.join(model_dir, 'model.npz'), model)\n",
    "    return model.predictor\n",
    "```\n",
    "* `input_fn(input_data, content_type)`: This function is invoked to deserialize prediction data when a prediction request is made. The return value is passed to predict_fn. `input_data` is the serialized input data in the body of the prediction request, and `content_type`, the MIME type of the data.\n",
    "  \n",
    "  \n",
    "* `predict_fn(input_data, model)`: This function accepts the return value of `input_fn` as the `input_data` parameter and the return value of `model_fn` as the `model` parameter and returns inferences obtained from the model.\n",
    "  \n",
    "  \n",
    "* `output_fn(prediction, accept)`: This function is invoked to serialize the return value from `predict_fn`, which is passed in as the `prediction` parameter, back to the SageMaker client in response to prediction requests.\n",
    "\n",
    "\n",
    "`model_fn` is always required, but default implementations exist for the remaining functions. These default implementations can deserialize a NumPy array, invoking the model's `__call__` method on the input data, and serialize a NumPy array back to the client.\n",
    "\n",
    "This notebook relies on the default `input_fn`, `predict_fn`, and `output_fn` implementations. See the Chainer sentiment analysis notebook for an example of how one can implement these hosting functions.\n",
    "\n",
    "Please examine the script below. Training occurs behind the main guard, which prevents the function from being run when the script is imported, and `model_fn` loads the model saved into `model_dir` during training.\n",
    "\n",
    "\n",
    "\n",
    "For more on writing Chainer scripts to run on SageMaker, or for more on the Chainer container itself, please see the following repositories: \n",
    "\n",
    "* For writing Chainer scripts to run on SageMaker: https://github.com/aws/sagemaker-python-sdk\n",
    "* For more on the Chainer container and default hosting functions: https://github.com/aws/sagemaker-chainer-containers\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "!pygmentize 'src/chainer_cifar_vgg_single_machine.py'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running the training script on SageMaker\n",
    "\n",
    "To train a model with a Chainer script, we construct a ```Chainer``` estimator using the [sagemaker-python-sdk](https://github.com/aws/sagemaker-python-sdk). We pass in an `entry_point`, the name of a script that contains a couple of functions with certain signatures (`train` and `model_fn`), and a `source_dir`, a directory containing all code to run inside the Chainer container. This script will be run on SageMaker in a container that invokes these functions to train and load Chainer models. \n",
    "\n",
    "The ```Chainer``` class allows us to run our training function as a training job on SageMaker infrastructure. We need to configure it with our training script, an IAM role, the number of training instances, and the training instance type. In this case we will run our training job on one `ml.p2.xlarge` instance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from sagemaker.chainer.estimator import Chainer\n",
    "\n",
    "chainer_estimator = Chainer(entry_point='chainer_cifar_vgg_single_machine.py',\n",
    "                            source_dir=\"src\",\n",
    "                            role=role,\n",
    "                            sagemaker_session=sagemaker_session,\n",
    "                            train_instance_count=1,\n",
    "                            train_instance_type='ml.p2.xlarge',\n",
    "                            hyperparameters={'epochs': 50, 'batch-size': 64})\n",
    "\n",
    "chainer_estimator.fit({'train': train_input, 'test': test_input})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our Chainer script writes various artifacts, such as plots, to a directory `output_data_dir`, the contents of which which SageMaker uploads to S3. Now we download and extract these artifacts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from s3_util import retrieve_output_from_s3\n",
    "\n",
    "chainer_training_job = chainer_estimator.latest_training_job.name\n",
    "\n",
    "desc = sagemaker_session.sagemaker_client. \\\n",
    "           describe_training_job(TrainingJobName=chainer_training_job)\n",
    "output_data = desc['ModelArtifacts']['S3ModelArtifacts'].replace('model.tar.gz', 'output.tar.gz')\n",
    "\n",
    "retrieve_output_from_s3(output_data, 'output/single_machine_cifar')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These plots show the accuracy and loss over epochs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from IPython.display import Image\n",
    "from IPython.display import display\n",
    "\n",
    "accuracy_graph = Image(filename=\"output/single_machine_cifar/accuracy.png\",\n",
    "                       width=800,\n",
    "                       height=800)\n",
    "loss_graph = Image(filename=\"output/single_machine_cifar/loss.png\",\n",
    "                   width=800,\n",
    "                   height=800)\n",
    "\n",
    "display(accuracy_graph, loss_graph)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Deploying the Trained Model\n",
    "\n",
    "After training, we use the Chainer estimator object to create and deploy a hosted prediction endpoint. We can use a CPU-based instance for inference (in this case an `ml.m4.xlarge`), even though we trained on GPU instances.\n",
    "\n",
    "The predictor object returned by `deploy` lets us call the new endpoint and perform inference on our sample images. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "predictor = chainer_estimator.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CIFAR10 sample images\n",
    "\n",
    "We'll use these CIFAR10 sample images to test the service:\n",
    "\n",
    "<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/airplane1.png\" />\n",
    "<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/automobile1.png\" />\n",
    "<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/bird1.png\" />\n",
    "<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/cat1.png\" />\n",
    "<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/deer1.png\" />\n",
    "<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/dog1.png\" />\n",
    "<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/frog1.png\" />\n",
    "<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/horse1.png\" />\n",
    "<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/ship1.png\" />\n",
    "<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/truck1.png\" />\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predicting using SageMaker Endpoint\n",
    "\n",
    "We batch the images together into a single NumPy array to obtain multiple inferences with a single prediction request."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from skimage import io\n",
    "import numpy as np\n",
    "\n",
    "def read_image(filename):\n",
    "    img = io.imread(filename)\n",
    "    img = np.array(img).transpose(2, 0, 1)\n",
    "    img = np.expand_dims(img, axis=0)\n",
    "    img = img.astype(np.float32)\n",
    "    img *= 1. / 255.\n",
    "    img = img.reshape(3, 32, 32)\n",
    "    return img\n",
    "\n",
    "\n",
    "def read_images(filenames):\n",
    "    return np.array([read_image(f) for f in filenames])\n",
    "\n",
    "filenames = ['images/airplane1.png',\n",
    "             'images/automobile1.png',\n",
    "             'images/bird1.png',\n",
    "             'images/cat1.png',\n",
    "             'images/deer1.png',\n",
    "             'images/dog1.png',\n",
    "             'images/frog1.png',\n",
    "             'images/horse1.png',\n",
    "             'images/ship1.png',\n",
    "             'images/truck1.png']\n",
    "\n",
    "image_data = read_images(filenames)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The predictor runs inference on our input data and returns a list of predictions whose argmax gives the predicted label of the input data. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "response = predictor.predict(image_data)\n",
    "\n",
    "for i, prediction in enumerate(response):\n",
    "    print('image {}: prediction: {}'.format(i, prediction.argmax(axis=0)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cleanup\n",
    "\n",
    "After you have finished with this example, remember to delete the prediction endpoint to release the instance(s) associated with it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "chainer_estimator.delete_endpoint()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_chainer_p36",
   "language": "python",
   "name": "conda_chainer_p36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
